
Conversation

@tgymnich tgymnich commented Jun 30, 2025

  • We'll need to add a scalar version of amdgpu.scaled_{ext,trunc}_packed in the future (a scalar example is sketched after this description).

cc @umangyadav

fixes iree-org/iree#20821
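
For context on the first bullet: amdgpu.scaled_ext_packed only takes vector operands, so the pattern in this PR handles a scalar arith.scaling_extf by splatting the input into a vector<1>, running the packed op, and extracting lane 0 of the two-wide result. A minimal sketch of such a scalar use (function name and element types are illustrative, not taken from this PR's tests):

// Hypothetical scalar use; a native scalar amdgpu op would avoid the
// vector<1> splat / extract round-trip described above.
func.func @scalar_scaling_ext(%v: f8E5M2, %scale: f8E8M0FNU) -> f32 {
  %r = arith.scaling_extf %v, %scale : f8E5M2, f8E8M0FNU to f32
  return %r : f32
}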

@llvmbot (Member) commented Jun 30, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-mlir-gpu

Author: Tim Gymnich (tgymnich)

Changes

Patch is 52.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146372.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+165)
  • (added) mlir/test/Conversion/ArithToAMDGPU/scale_ext.mlir (+553)
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 3596b3235a631..22cd4703c6005 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -15,11 +15,17 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/LogicalResult.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
@@ -32,6 +38,7 @@ using namespace mlir::amdgpu;
 namespace {
 // Define commonly used chipsets versions for convenience.
 constexpr Chipset kGfx942 = Chipset(9, 4, 2);
+constexpr Chipset kGfx950 = Chipset(9, 5, 0);
 
 struct ArithToAMDGPUConversionPass final
     : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
@@ -73,6 +80,28 @@ struct TruncfToFloat16RewritePattern final
                                 PatternRewriter &rewriter) const override;
 };
 
+struct ScalingExtFRewritePattern final : OpRewritePattern<arith::ScalingExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  Chipset chipset;
+  ScalingExtFRewritePattern(MLIRContext *ctx, Chipset chipset)
+      : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
+
+  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+struct ScalingTruncFRewritePattern final : OpRewritePattern<arith::ScalingTruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  Chipset chipset;
+  ScalingTruncFRewritePattern(MLIRContext *ctx, Chipset chipset)
+      : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
+
+  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
 } // end namespace
 
 static bool isSupportedF8(Type elementType, Chipset chipset) {
@@ -395,6 +424,137 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   return success();
 }
 
+static Value getOriginalVectorValue(Value value) {
+  Value current = value;
+  while (Operation *definingOp = current.getDefiningOp()) {
+    bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp)
+      .Case<vector::ShapeCastOp>(
+        [&current](auto op) {
+          current = op.getSource();
+          return true;
+        })
+      .Case<vector::BroadcastOp>(
+        [&current](auto op) {
+          current = op.getSource();
+          return false;
+        })
+      .Case<vector::SplatOp>(
+        [&current](auto op) {
+          current = op.getInput();
+          return false;
+        })
+      .Default([](Operation *) {
+        return false;
+      });
+
+    if (!skipOp) {
+      break;
+    }
+  }
+  return current;
+}
+
+LogicalResult
+ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
+                                           PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  constexpr const int64_t opWidth = 2;
+
+  Value in = op.getIn();
+  Value scale = op.getScale();
+  Value out = op.getOut();
+
+  Type f32 = rewriter.getF32Type();
+  Type inType = getElementTypeOrSelf(in);
+  Type scaleType = getElementTypeOrSelf(scale);
+  Type outType = getElementTypeOrSelf(out);
+  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
+  VectorType inVecType = dyn_cast<VectorType>(in.getType());
+  VectorType outVecType = dyn_cast<VectorType>(out.getType());
+
+  if (outVecType && outVecType.isScalable())
+    return failure();
+
+  Type scaleF32Type = scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
+  if (scaleType.getIntOrFloatBitWidth() < 32)
+    scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+  else if (scaleType.getIntOrFloatBitWidth() > 32)
+    scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+
+  VectorType extScaleResultType = VectorType::get(opWidth, outType);
+
+  if (!outVecType) {
+      Value inCast = rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+      Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(loc, extScaleResultType, inCast, scale, 0);
+      scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
+      return success();
+  }
+
+  Value origScale = getOriginalVectorValue(scale);
+  Type origScaleType = origScale.getType();
+  VectorType origScaleVecType = isa<VectorType>(origScaleType) ? cast<VectorType>(origScaleType) : VectorType::get(1, origScaleType);
+  
+  ArrayRef<int64_t> originalScaleShape = origScaleVecType.getShape();
+  ArrayRef<int64_t> inShape = inVecType.getShape();
+
+  SmallVector<int64_t> paddedScaleShape(originalScaleShape);
+  paddedScaleShape.insert(paddedScaleShape.end(), inShape.size() - originalScaleShape.size(),
+                       1);
+
+  auto ratio = computeShapeRatio(inShape, paddedScaleShape);
+  if (!ratio)
+    return failure();
+
+  const int64_t blockSize = computeProduct(*ratio);
+
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, outType, rewriter.getFloatAttr(outType, 0.0));
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+
+  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, *ratio)) {
+    SmallVector<int64_t> strides(offsets.size(), 1);
+    Value block = rewriter.create<vector::ExtractStridedSliceOp>(
+        loc, in, offsets, *ratio, strides);
+    VectorType block1DType = VectorType::get(blockSize, inType);
+    Value block1D =
+        rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+    Value uniformScale =
+        rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+
+    VectorType blockResultType = VectorType::get(blockSize, outType);
+    Value blockResult =
+        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+
+    for (int64_t i = 0, sliceWidth = opWidth - blockSize % opWidth;
+         i < blockSize;
+         i += sliceWidth, sliceWidth = opWidth - blockSize % opWidth) {
+      Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, block1D, i, sliceWidth, 1);
+      Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
+          loc, extScaleResultType, slice, uniformScale, 0);
+      if (sliceWidth != opWidth)
+        scaleExt = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, scaleExt, 0, sliceWidth, 1);
+      blockResult = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, scaleExt, blockResult, i, 1);
+    }
+
+    VectorType resultType = VectorType::get(*ratio, outType);
+    Value cast = rewriter.create<vector::ShapeCastOp>(loc, resultType,
+                                                            blockResult);
+    result = rewriter.create<vector::InsertStridedSliceOp>(
+        loc, cast, result, offsets, strides);
+  }
+
+  rewriter.replaceOp(op, result);
+
+  return success();
+}
+
+LogicalResult ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, PatternRewriter &rewriter) const {
+  return success();
+}
+
 void mlir::arith::populateArithToAMDGPUConversionPatterns(
     RewritePatternSet &patterns, bool convertFP8Arithmetic,
     bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
@@ -406,6 +566,11 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
   }
   if (allowPackedF16Rtz)
     patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
+
+  if (chipset >= kGfx950) {
+    patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), chipset);
+    patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), chipset);
+  }
 }
 
 void ArithToAMDGPUConversionPass::runOnOperation() {
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scale_ext.mlir b/mlir/test/Conversion/ArithToAMDGPU/scale_ext.mlir
new file mode 100644
index 0000000000000..1669926ae48cc
--- /dev/null
+++ b/mlir/test/Conversion/ArithToAMDGPU/scale_ext.mlir
@@ -0,0 +1,553 @@
+// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s
+
+// CHECK-LABEL: @conversion_f8_fallback
+// CHECK:         [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK-NEXT:    [[SCALE_EXT:%.+]] = arith.extf %arg1 : vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+// CHECK-NEXT:    [[IN_SLICE_00:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT:    [[IN_SCALAR_00:%.+]] = vector.shape_cast [[IN_SLICE_00]]
+// CHECK-NEXT:    [[SCALE_SCALAR_00:%.+]] = vector.extract [[SCALE_EXT]][0, 0]
+// CHECK-NEXT:    [[PACKED_00:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_00]][0], [[SCALE_SCALAR_00]]
+// CHECK-NEXT:    [[OUT_SLICE_00:%.+]] = vector.extract_strided_slice [[PACKED_00]]
+// CHECK-NEXT:    [[OUT_SCALAR_00:%.+]] = vector.shape_cast [[OUT_SLICE_00]]
+// CHECK-NEXT:    [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_00]], [[CST]]
+// CHECK-NEXT:    [[IN_SLICE_01:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT:    [[IN_SCALAR_01:%.+]] = vector.shape_cast [[IN_SLICE_01]]
+// CHECK-NEXT:    [[SCALE_SCALAR_01:%.+]] = vector.extract [[SCALE_EXT]][0, 1]
+// CHECK-NEXT:    [[PACKED_01:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_01]][0], [[SCALE_SCALAR_01]]
+// CHECK-NEXT:    [[OUT_SLICE_01:%.+]] = vector.extract_strided_slice [[PACKED_01]]
+// CHECK-NEXT:    [[OUT_SCALAR_01:%.+]] = vector.shape_cast [[OUT_SLICE_01]]
+// CHECK-NEXT:    [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_01]], [[ACC_A]]
+// CHECK-NEXT:    [[IN_SLICE_10:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT:    [[IN_SCALAR_10:%.+]] = vector.shape_cast [[IN_SLICE_10]]
+// CHECK-NEXT:    [[SCALE_SCALAR_10:%.+]] = vector.extract [[SCALE_EXT]][1, 0]
+// CHECK-NEXT:    [[PACKED_10:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_10]][0], [[SCALE_SCALAR_10]]
+// CHECK-NEXT:    [[OUT_SLICE_10:%.+]] = vector.extract_strided_slice [[PACKED_10]]
+// CHECK-NEXT:    [[OUT_SCALAR_10:%.+]] = vector.shape_cast [[OUT_SLICE_10]]
+// CHECK-NEXT:    [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_10]], [[ACC_B]]
+// CHECK-NEXT:    [[IN_SLICE_11:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT:    [[IN_SCALAR_11:%.+]] = vector.shape_cast [[IN_SLICE_11]]
+// CHECK-NEXT:    [[SCALE_SCALAR_11:%.+]] = vector.extract [[SCALE_EXT]][1, 1]
+// CHECK-NEXT:    [[PACKED_11:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_11]][0], [[SCALE_SCALAR_11]]
+// CHECK-NEXT:    [[OUT_SLICE_11:%.+]] = vector.extract_strided_slice [[PACKED_11]]
+// CHECK-NEXT:    [[OUT_SCALAR_11:%.+]] = vector.shape_cast [[OUT_SLICE_11]]
+// CHECK-NEXT:    [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_11]], [[ACC_A]]
+// CHECK-NEXT:    return [[ACC_B]] : vector<2x2xf32>
+func.func @conversion_f8_fallback(%in: vector<2x2xf8E5M2>, %scale: vector<2x2xf8E8M0FNU>) -> vector<2x2xf32> {
+    %ext = arith.scaling_extf %in, %scale : vector<2x2xf8E5M2>, vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+    return %ext : vector<2x2xf32>
+}
+
+// CHECK-LABEL: @conversion_f4_fallback
+// CHECK:         [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK-NEXT:    [[SCALE_EXT:%.+]] = arith.extf %arg1 : vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+// CHECK-NEXT:    [[IN_SLICE_00:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT:    [[IN_SCALAR_00:%.+]] = vector.shape_cast [[IN_SLICE_00]]
+// CHECK-NEXT:    [[SCALE_SCALAR_00:%.+]] = vector.extract [[SCALE_EXT]][0, 0]
+// CHECK-NEXT:    [[PACKED_00:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_00]][0], [[SCALE_SCALAR_00]]
+// CHECK-NEXT:    [[OUT_SLICE_00:%.+]] = vector.extract_strided_slice [[PACKED_00]]
+// CHECK-NEXT:    [[OUT_SCALAR_00:%.+]] = vector.shape_cast [[OUT_SLICE_00]]
+// CHECK-NEXT:    [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_00]], [[CST]]
+// CHECK-NEXT:    [[IN_SLICE_01:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT:    [[IN_SCALAR_01:%.+]] = vector.shape_cast [[IN_SLICE_01]]
+// CHECK-NEXT:    [[SCALE_SCALAR_01:%.+]] = vector.extract [[SCALE_EXT]][0, 1]
+// CHECK-NEXT:    [[PACKED_01:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_01]][0], [[SCALE_SCALAR_01]]
+// CHECK-NEXT:    [[OUT_SLICE_01:%.+]] = vector.extract_strided_slice [[PACKED_01]]
+// CHECK-NEXT:    [[OUT_SCALAR_01:%.+]] = vector.shape_cast [[OUT_SLICE_01]]
+// CHECK-NEXT:    [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_01]], [[ACC_A]]
+// CHECK-NEXT:    [[IN_SLICE_10:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT:    [[IN_SCALAR_10:%.+]] = vector.shape_cast [[IN_SLICE_10]]
+// CHECK-NEXT:    [[SCALE_SCALAR_10:%.+]] = vector.extract [[SCALE_EXT]][1, 0]
+// CHECK-NEXT:    [[PACKED_10:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_10]][0], [[SCALE_SCALAR_10]]
+// CHECK-NEXT:    [[OUT_SLICE_10:%.+]] = vector.extract_strided_slice [[PACKED_10]]
+// CHECK-NEXT:    [[OUT_SCALAR_10:%.+]] = vector.shape_cast [[OUT_SLICE_10]]
+// CHECK-NEXT:    [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_10]], [[ACC_B]]
+// CHECK-NEXT:    [[IN_SLICE_11:%.+]] = vector.extract_strided_slice %arg0
+// CHECK-NEXT:    [[IN_SCALAR_11:%.+]] = vector.shape_cast [[IN_SLICE_11]]
+// CHECK-NEXT:    [[SCALE_SCALAR_11:%.+]] = vector.extract [[SCALE_EXT]][1, 1]
+// CHECK-NEXT:    [[PACKED_11:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_11]][0], [[SCALE_SCALAR_11]]
+// CHECK-NEXT:    [[OUT_SLICE_11:%.+]] = vector.extract_strided_slice [[PACKED_11]]
+// CHECK-NEXT:    [[OUT_SCALAR_11:%.+]] = vector.shape_cast [[OUT_SLICE_11]]
+// CHECK-NEXT:    [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_11]], [[ACC_A]]
+// CHECK-NEXT:    return [[ACC_B]] : vector<2x2xf32>
+func.func @conversion_f4_fallback(%in: vector<2x2xf4E2M1FN>, %scale: vector<2x2xf8E8M0FNU>) -> vector<2x2xf32> {
+    %ext = arith.scaling_extf %in, %scale : vector<2x2xf4E2M1FN>, vector<2x2xf8E8M0FNU> to vector<2x2xf32>
+    return %ext : vector<2x2xf32>
+}
+
+
+// CHECK-LABEL: @conversion_broadcast
+// CHECK:       [[CST:%.+]] = arith.constant dense<0.000000e+00> : vector<8x2x4xf32>
+// CHECK-NEXT:  [[BCAST:%.+]] = vector.broadcast %arg1 : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU>
+// CHECK-NEXT:  [[IN_CAST:%.+]] = vector.shape_cast %arg0 : vector<8x8xf8E5M2> to vector<8x2x4xf8E5M2>
+// CHECK-NEXT:  [[SCALE_CAST:%.+]] = vector.shape_cast [[BCAST]] : vector<4x8x2xf8E8M0FNU> to vector<8x2x4xf8E8M0FNU>
+// CHECK-NEXT:  [[SCALE_EXT:%.+]] = arith.extf [[SCALE_CAST]] : vector<8x2x4xf8E8M0FNU> to vector<8x2x4xf32>
+// CHECK-NEXT:  [[IN_SLICE_0:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT:  [[IN_SCALAR_0:%.+]] = vector.shape_cast [[IN_SLICE_0]]
+// CHECK-NEXT:  [[SCALE_SCALAR_0:%.+]] = vector.extract [[SCALE_EXT]][0, 0, 0]
+// CHECK-NEXT:  [[PACKED_0:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_0]][0], [[SCALE_SCALAR_0]]
+// CHECK-NEXT:  [[OUT_SLICE_0:%.+]] = vector.extract_strided_slice [[PACKED_0]]
+// CHECK-NEXT:  [[OUT_SCALAR_0:%.+]] = vector.shape_cast [[OUT_SLICE_0]]
+// CHECK-NEXT:  [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_0]], [[CST]]
+// CHECK-NEXT:  [[IN_SLICE_1:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT:  [[IN_SCALAR_1:%.+]] = vector.shape_cast [[IN_SLICE_1]]
+// CHECK-NEXT:  [[SCALE_SCALAR_1:%.+]] = vector.extract [[SCALE_EXT]][0, 0, 1]
+// CHECK-NEXT:  [[PACKED_1:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_1]][0], [[SCALE_SCALAR_1]]
+// CHECK-NEXT:  [[OUT_SLICE_1:%.+]] = vector.extract_strided_slice [[PACKED_1]]
+// CHECK-NEXT:  [[OUT_SCALAR_1:%.+]] = vector.shape_cast [[OUT_SLICE_1]]
+// CHECK-NEXT:  [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_1]], [[ACC_A]]
+// CHECK-NEXT:  [[IN_SLICE_2:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT:  [[IN_SCALAR_2:%.+]] = vector.shape_cast [[IN_SLICE_2]]
+// CHECK-NEXT:  [[SCALE_SCALAR_2:%.+]] = vector.extract [[SCALE_EXT]][0, 0, 2]
+// CHECK-NEXT:  [[PACKED_2:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_2]][0], [[SCALE_SCALAR_2]]
+// CHECK-NEXT:  [[OUT_SLICE_2:%.+]] = vector.extract_strided_slice [[PACKED_2]]
+// CHECK-NEXT:  [[OUT_SCALAR_2:%.+]] = vector.shape_cast [[OUT_SLICE_2]]
+// CHECK-NEXT:  [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_2]], [[ACC_B]]
+// CHECK-NEXT:  [[IN_SLICE_3:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT:  [[IN_SCALAR_3:%.+]] = vector.shape_cast [[IN_SLICE_3]]
+// CHECK-NEXT:  [[SCALE_SCALAR_3:%.+]] = vector.extract [[SCALE_EXT]][0, 0, 3]
+// CHECK-NEXT:  [[PACKED_3:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_3]][0], [[SCALE_SCALAR_3]]
+// CHECK-NEXT:  [[OUT_SLICE_3:%.+]] = vector.extract_strided_slice [[PACKED_3]]
+// CHECK-NEXT:  [[OUT_SCALAR_3:%.+]] = vector.shape_cast [[OUT_SLICE_3]]
+// CHECK-NEXT:  [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_3]], [[ACC_A]]
+// CHECK-NEXT:  [[IN_SLICE_4:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT:  [[IN_SCALAR_4:%.+]] = vector.shape_cast [[IN_SLICE_4]]
+// CHECK-NEXT:  [[SCALE_SCALAR_4:%.+]] = vector.extract [[SCALE_EXT]][0, 1, 0]
+// CHECK-NEXT:  [[PACKED_4:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_4]][0], [[SCALE_SCALAR_4]]
+// CHECK-NEXT:  [[OUT_SLICE_4:%.+]] = vector.extract_strided_slice [[PACKED_4]]
+// CHECK-NEXT:  [[OUT_SCALAR_4:%.+]] = vector.shape_cast [[OUT_SLICE_4]]
+// CHECK-NEXT:  [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_4]], [[ACC_B]]
+// CHECK-NEXT:  [[IN_SLICE_5:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT:  [[IN_SCALAR_5:%.+]] = vector.shape_cast [[IN_SLICE_5]]
+// CHECK-NEXT:  [[SCALE_SCALAR_5:%.+]] = vector.extract [[SCALE_EXT]][0, 1, 1]
+// CHECK-NEXT:  [[PACKED_5:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_5]][0], [[SCALE_SCALAR_5]]
+// CHECK-NEXT:  [[OUT_SLICE_5:%.+]] = vector.extract_strided_slice [[PACKED_5]]
+// CHECK-NEXT:  [[OUT_SCALAR_5:%.+]] = vector.shape_cast [[OUT_SLICE_5]]
+// CHECK-NEXT:  [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_5]], [[ACC_A]]
+// CHECK-NEXT:  [[IN_SLICE_6:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT:  [[IN_SCALAR_6:%.+]] = vector.shape_cast [[IN_SLICE_6]]
+// CHECK-NEXT:  [[SCALE_SCALAR_6:%.+]] = vector.extract [[SCALE_EXT]][0, 1, 2]
+// CHECK-NEXT:  [[PACKED_6:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_6]][0], [[SCALE_SCALAR_6]]
+// CHECK-NEXT:  [[OUT_SLICE_6:%.+]] = vector.extract_strided_slice [[PACKED_6]]
+// CHECK-NEXT:  [[OUT_SCALAR_6:%.+]] = vector.shape_cast [[OUT_SLICE_6]]
+// CHECK-NEXT:  [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_6]], [[ACC_B]]
+// CHECK-NEXT:  [[IN_SLICE_7:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT:  [[IN_SCALAR_7:%.+]] = vector.shape_cast [[IN_SLICE_7]]
+// CHECK-NEXT:  [[SCALE_SCALAR_7:%.+]] = vector.extract [[SCALE_EXT]][0, 1, 3]
+// CHECK-NEXT:  [[PACKED_7:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_7]][0], [[SCALE_SCALAR_7]]
+// CHECK-NEXT:  [[OUT_SLICE_7:%.+]] = vector.extract_strided_slice [[PACKED_7]]
+// CHECK-NEXT:  [[OUT_SCALAR_7:%.+]] = vector.shape_cast [[OUT_SLICE_7]]
+// CHECK-NEXT:  [[ACC_B:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_7]], [[ACC_A]]
+// CHECK-NEXT:  [[IN_SLICE_8:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT:  [[IN_SCALAR_8:%.+]] = vector.shape_cast [[IN_SLICE_8]]
+// CHECK-NEXT:  [[SCALE_SCALAR_8:%.+]] = vector.extract [[SCALE_EXT]][1, 0, 0]
+// CHECK-NEXT:  [[PACKED_8:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_8]][0], [[SCALE_SCALAR_8]]
+// CHECK-NEXT:  [[OUT_SLICE_8:%.+]] = vector.extract_strided_slice [[PACKED_8]]
+// CHECK-NEXT:  [[OUT_SCALAR_8:%.+]] = vector.shape_cast [[OUT_SLICE_8]]
+// CHECK-NEXT:  [[ACC_A:%.+]] = vector.insert_strided_slice [[OUT_SCALAR_8]], [[ACC_B]]
+// CHECK-NEXT:  [[IN_SLICE_9:%.+]] = vector.extract_strided_slice [[IN_CAST]]
+// CHECK-NEXT:  [[IN_SCALAR_9:%.+]] = vector.shape_cast [[IN_SLICE_9]]
+// CHECK-NEXT:  [[SCALE_SCALAR_9:%.+]] = vector.extract [[SCALE_EXT]][1, 0, 1]
+// CHECK-NEXT:  [[PACKED_9:%.+]] = amdgpu.scaled_ext_packed [[IN_SCALAR_9]][0], [[SCALE_SCALAR_9]]
+// CHECK-NEXT:  [[OUT_SLICE_9:%.+]]...
[truncated]
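
To summarize the vector path without wading through the truncated test above: the pattern extends the scale to f32, tiles the input by the input/scale shape ratio, and, inside each tile, converts two elements at a time with amdgpu.scaled_ext_packed before stitching the pieces back together. A minimal sketch of one two-element step, with hypothetical value names and with op syntax reconstructed from the pattern (it may not match the exact printed form):

// One per-slice step, assuming a 4-element block of f8E5M2 inputs that
// shares a single f8E8M0FNU scale (%block1d and %acc0 are hypothetical).
%scale_f32 = arith.extf %scale : f8E8M0FNU to f32
%slice = vector.extract_strided_slice %block1d
    {offsets = [0], sizes = [2], strides = [1]}
    : vector<4xf8E5M2> to vector<2xf8E5M2>
%packed = amdgpu.scaled_ext_packed %slice[0], %scale_f32
    : vector<2xf8E5M2> to vector<2xf32>
%acc1 = vector.insert_strided_slice %packed, %acc0
    {offsets = [0], strides = [1]}
    : vector<2xf32> into vector<4xf32>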

@tgymnich tgymnich force-pushed the tim/arith-scaling-to-amdgpu branch from 591ac0d to 3fee107 on June 30, 2025 20:09
Member

Suggested change:
-    for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, *ratio)) {
+    for (const SmallVector<int64_t> &offsets : StaticTileOffsetRange(inShape, *ratio)) {

let's not copy vectors.

Member Author

this vector is created in-place on the stack.

Member

We'd still copy the elements. There may be copy elision but I didn't check if it is guaranteed here.

@tgymnich tgymnich (Member Author) commented Jul 2, 2025

The vector is moved or copied either way.

@tgymnich tgymnich force-pushed the tim/arith-scaling-to-amdgpu branch from 37e519d to bad8e78 on July 1, 2025 11:43
@ftynse ftynse (Member) left a comment

Please find a better way to check the broadcast conversion than 400 lines of CHECK-NEXT; this is not very maintainable :)
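
One possible direction, offered purely as an illustration (not what was ultimately committed): count the repeated conversions with FileCheck instead of enumerating every slice. Shown on the smaller f8 fallback test, where exactly four packed conversions are expected:

// Hypothetical compaction of the checks for @conversion_f8_fallback.
// CHECK-LABEL: @conversion_f8_fallback
// CHECK: arith.extf %arg1 : vector<2x2xf8E8M0FNU> to vector<2x2xf32>
// CHECK-COUNT-4: amdgpu.scaled_ext_packed
// CHECK: return {{.*}} : vector<2x2xf32>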

@tgymnich tgymnich force-pushed the tim/arith-scaling-to-amdgpu branch from 9ae4dbb to a6fbc23 on July 2, 2025 16:46
@tgymnich tgymnich force-pushed the tim/arith-scaling-to-amdgpu branch from a6fbc23 to b76a3e8 on July 8, 2025 13:10

github-actions bot commented Jul 8, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@umangyadav umangyadav (Contributor) left a comment

Fix formatting, otherwise LGTM. Thanks for working on this.

@tgymnich tgymnich force-pushed the tim/arith-scaling-to-amdgpu branch from b50a495 to efc6194 on July 8, 2025 14:44
@tgymnich tgymnich merged commit 6f291cb into llvm:main Jul 8, 2025
9 checks passed
@tgymnich tgymnich changed the title from "[mlir][amdgpu] Add conversion from arith.scaling_extf to amdgpu" to "[mlir][amdgpu] Add conversion from arith.scaling_{extf,truncf} to amdgpu" on Jul 8, 2025
Successfully merging this pull request may close these issues:

  • Lower arith.scaling_extf and arith.scaling_truncf to amdgpu.scaled_ext_packed and amdgpu.packed_scaled_trunc